# OrtSAE: Orthogonal Sparse Autoencoders Uncover Atomic Features

## Overview

This repository provides the code to reproduce the results from the paper *"OrtSAE: Orthogonal Sparse Autoencoders Uncover Atomic Features"*. It is organized into three main components:

- The `dictionary_learning` folder, containing code for training baseline Sparse Autoencoders (SAEs) and OrtSAE.
- The `SAEBench` folder, which includes code for evaluating trained SAEs using SAEBench.
- The `atomicity_metrics` folder, with code for investigating the atomicity of trained SAEs.

## Installation

1. Ensure you have Python 3.12 installed on your system.
2. Install the required dependencies by running:

   ```bash
   pip install -r requirements.txt
   ```

## Training SAEs

The training code is sourced from the `dictionary_learning` [repository](https://github.com/saprmarks/dictionary_learning.git), with OrtSAE implemented as a new trainer in the `trainers/batch_top_k_ort.py` file. To train SAEs, follow these steps:

1. Configure your training hyperparameters in `training_config.py`. By default, the hyperparameters are configured according to the experimental setup described in the paper.
2. Run the training script with the following command:

   ```bash
   python run_training.py --save_dir ./trained_saes/ort_sae --model_name google/gemma-2-2b --layers 12 --architectures batch_top_k_ort --device cuda:0
   ```
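
To train at several layers, the command above can be generated once per layer. The sketch below only builds and prints the commands, reusing the flags shown above. The layer list and the per-layer subdirectory under `--save_dir` are illustrative assumptions, not conventions taken from this repository.

```python
import shlex


def build_training_command(
    layer,
    save_dir="./trained_saes/ort_sae",
    model_name="google/gemma-2-2b",
    architecture="batch_top_k_ort",
    device="cuda:0",
):
    """Return the argument list for one run_training.py invocation."""
    return [
        "python", "run_training.py",
        # Assumption: one subdirectory per layer, so runs do not share a folder.
        "--save_dir", f"{save_dir}/layer_{layer}",
        "--model_name", model_name,
        "--layers", str(layer),
        "--architectures", architecture,
        "--device", device,
    ]


# Hypothetical layer choices, for illustration only.
for layer in (6, 12):
    print(shlex.join(build_training_command(layer)))
```

Each printed line can be pasted into a shell, or the list can be passed to `subprocess.run(cmd, check=True)` to launch the runs sequentially.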

To upload your trained models to Hugging Face, specify the folder containing the trained SAEs in `push_to_hf.py`, then run the script:

```bash
python push_to_hf.py
```

## Evaluating SAEs with SAEBench

The evaluation code is adapted from the `SAEBench` [repository](https://github.com/adamkarvonen/SAEBench.git). You can use it to compute the following metrics for SAEs trained with dictionary learning: explained variance, KL-divergence score, autointerp score, absorption, SCR, TPP, sparse probing, and RAVEL. To evaluate your SAEs, follow these steps:

1. Upload the trained SAEs to Hugging Face using the `push_to_hf.py` script.
2. Add the Hugging Face ID of your uploaded SAEs to `SAEBench/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py`.
3. In `run_all_evals_dictionary_learning_saes.py`, specify which metrics you want to compute.
4. Run the evaluation script:

   ```bash
   python SAEBench/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py
   ```

## Evaluating SAE Atomicity

This section focuses on evaluating the atomicity of SAEs by analyzing composition, clustering rates, and cross-model feature overlap.

Before running any atomicity analysis, run the core evaluation from SAEBench. Its results are used to filter out dead features, which would otherwise distort the atomicity metrics.
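
As a rough illustration of what filtering dead features means, the sketch below keeps only the features that fire at least once. The input format (a plain list of per-feature firing frequencies) and the function name are hypothetical; the scripts in this repository read the SAEBench results in their own format.

```python
def alive_feature_indices(activation_frequencies, min_frequency=0.0):
    """Return indices of features that fire more often than min_frequency."""
    return [
        index
        for index, frequency in enumerate(activation_frequencies)
        if frequency > min_frequency
    ]


# Hypothetical firing frequencies: fraction of tokens on which each feature is active.
frequencies = [0.012, 0.0, 0.0031, 0.0, 0.2]

alive = alive_feature_indices(frequencies)
print(alive)  # [0, 2, 4]
```

Features 1 and 3 never activate, so they are excluded from any downstream metric.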

To perform meta-SAE-based composition evaluation, execute the following command:

```bash
python atomicity_metrics/eval_with_meta_sae.py
```

For clustering and cross-model overlap analysis, use the Jupyter notebooks:

- `clustering.ipynb`
- `cross_model_overlap.ipynb`
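
For intuition about cross-model feature overlap, here is a minimal pure-Python sketch: a feature from one SAE counts as matched if its best cosine similarity with any feature of the other SAE reaches a threshold. The function names, the toy vectors, and the 0.7 threshold are assumptions made for illustration; the notebook may define overlap differently.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def overlap_fraction(features_a, features_b, threshold=0.7):
    """Fraction of features in A whose best match in B has cosine >= threshold."""
    if not features_a or not features_b:
        return 0.0
    matched = sum(
        1
        for u in features_a
        if max(cosine(u, v) for v in features_b) >= threshold
    )
    return matched / len(features_a)


# Toy feature directions (one vector per feature), purely illustrative.
sae_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
sae_b = [[0.9, 0.1], [-1.0, 0.0]]

print(overlap_fraction(sae_a, sae_b))
```

In this toy case the first and third features of `sae_a` have a close match in `sae_b`, while the second does not.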